
Conversation

Copilot AI (Contributor) commented on Jan 26, 2026

LabelAttentionClassifier.forward() creates label_indices on the CPU, which causes device mismatch errors when the model runs on a GPU.

Changes

  • Create label_indices with explicit device and dtype matching token_embeddings:
# Before
label_indices = torch.arange(self.num_classes).expand(B, -1)

# After  
label_indices = torch.arange(
    self.num_classes, dtype=torch.long, device=token_embeddings.device
).expand(B, -1)

This ensures the embedding lookup operates on tensors from the same device, preventing runtime errors in GPU environments.
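For context, here is a minimal sketch of how the corrected forward pass could look. The surrounding module structure (label embedding table, attention layer, hidden size, scorer) is an assumption for illustration and is not taken from the repository; only the label_indices construction reflects the actual change.

import torch
import torch.nn as nn

class LabelAttentionClassifier(nn.Module):
    """Sketch of a label-attention head; layer names and shapes are assumptions."""

    def __init__(self, hidden_dim: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.label_embedding = nn.Embedding(num_classes, hidden_dim)
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=1, batch_first=True)
        self.scorer = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings: torch.Tensor) -> torch.Tensor:
        # token_embeddings: (B, T, H)
        B = token_embeddings.size(0)
        # Build label indices on the same device (and with an integer dtype) as the
        # incoming activations, so the embedding lookup never mixes CPU and GPU tensors.
        label_indices = torch.arange(
            self.num_classes, dtype=torch.long, device=token_embeddings.device
        ).expand(B, -1)
        label_queries = self.label_embedding(label_indices)                      # (B, C, H)
        attended, _ = self.attention(label_queries, token_embeddings, token_embeddings)
        return self.scorer(attended).squeeze(-1)                                 # (B, C) logits

With the original code, moving the module to CUDA would leave label_indices on the CPU and the embedding lookup would fail at runtime; creating the indices from token_embeddings.device keeps the whole path on whichever device the inputs live on.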



Co-authored-by: meilame-tayebjee <114609737+meilame-tayebjee@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Update cross attention labels implementation based on feedback" to "Fix device placement for label_indices in LabelAttentionClassifier" on Jan 26, 2026
@meilame-tayebjee marked this pull request as ready for review on January 27, 2026 09:22
@meilame-tayebjee merged commit d266572 into 24-add-cross-attention-labels-text on January 27, 2026
@meilame-tayebjee deleted the copilot/sub-pr-60-again branch on January 27, 2026 09:22